
Add OneClassClassifier model #3501

Open

jeffpicard wants to merge 4 commits into master from 3496-one-class-classifier

Conversation

jeffpicard (Contributor)

This PR adds OneClassClassifier to flair.models for #3496.

The task, usage, and architecture are described in the class docstring.

The architecture is inspired by papers such as Anomaly Detection Using Autoencoders with Nonlinear Dimensionality Reduction. While this doesn't achieve state of the art or implement improvements like adding noise, I thought I'd see if you're interested, since it's a new task formulation that works and might be useful to others.

The interface requires users to set the threshold explicitly... I'm not sure if there's a cleaner way to hook that in so it happens automatically after training completes.
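In rough terms, the idea is to train an autoencoder on embeddings of the single known class and flag inputs with high reconstruction error as "<unk>". A minimal, self-contained sketch of that general idea (illustrative only, not the actual code in this PR; the dimensions and the 0.95 quantile are placeholders):

import torch
from torch import nn

# Sketch of autoencoder-based novelty detection (not the PR's implementation).
embedding_dim, bottleneck_dim = 768, 32

autoencoder = nn.Sequential(
    nn.Linear(embedding_dim, bottleneck_dim),  # encoder
    nn.ReLU(),
    nn.Linear(bottleneck_dim, embedding_dim),  # decoder
)

def reconstruction_error(embedded: torch.Tensor) -> torch.Tensor:
    # Training minimizes this on embeddings of the single in-distribution class.
    return ((autoencoder(embedded) - embedded) ** 2).mean(dim=-1)

# After training: pick a threshold from errors on held-out in-class data,
# then anything scoring above it at prediction time is labeled "<unk>".
dev_errors = reconstruction_error(torch.randn(100, embedding_dim))
threshold = torch.quantile(dev_errors, 0.95)
is_novel = reconstruction_error(torch.randn(8, embedding_dim)) > threshold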

Here's a short script demonstrating its usage separating IMDB from STACKOVERFLOW:

import json

from torch.utils.data import Subset

from flair.data import Sentence, Corpus
from flair.datasets import IMDB, STACKOVERFLOW
from flair.embeddings import TransformerWordEmbeddings
from flair.models.one_class_classification_model import OneClassClassifier
from flair.trainers import ModelTrainer

label_type = "sentiment"
embeddings = TransformerWordEmbeddings(
    model="xlm-roberta-base",
    is_document_embedding=True,
)

# Train on just IMDB, infer IMDB vs STACKOVERFLOW
corpus = Corpus(
    train=[x for x in Subset(IMDB().train, range(500))],
    test=[x for x in Subset(IMDB().test, range(250))]
    + [
        Sentence(x.text).add_label(typename=label_type, value="<unk>")
        for x in Subset(STACKOVERFLOW().test, range(250))
    ],
)

label_dictionary = corpus.make_label_dictionary(label_type)
model = OneClassClassifier(embeddings, label_dictionary, label_type=label_type)

trainer = ModelTrainer(model, corpus)
trainer.fine_tune("./tmp")

threshold = model.calculate_threshold(corpus.dev)
model.threshold = threshold
result = model.evaluate(corpus.test, gold_label_type=label_type)
print(json.dumps(result.classification_report, indent=2))

prints

{
  "POSITIVE": {
    "precision": 1.0,
    "recall": 1.0,
    "f1-score": 1.0,
    "support": 250.0
  },
  "<unk>": {
    "precision": 1.0,
    "recall": 1.0,
    "f1-score": 1.0,
    "support": 250.0
  },
  "accuracy": 1.0,
[...]
}

Thanks for any time you're willing to put into considering this :) !

@alanakbik (Collaborator)

@jeffpicard thanks for the PR!

@elenamer can you take a look?

@jeffpicard force-pushed the 3496-one-class-classifier branch from 5a4204c to 65459d3 on August 6, 2024 07:26
@jeffpicard (Contributor, Author)

Many thanks for the review! I've squashed in a commit with your requested changes (Implement mini_batch_size and verbose; Rename loss). @elenamer would you be willing to take another look please?

@elenamer (Collaborator) left a comment

Thanks for adding the OneClassClassifier, looks good to me now!

@jeffpicard force-pushed the 3496-one-class-classifier branch from 65459d3 to 5ab2bfc on August 7, 2024 18:47
@jeffpicard (Contributor, Author)

(CI was failing with errors that looked unrelated to this branch so I clicked the "rebase" button in the UI)

@alanakbik (Collaborator) left a comment

Thanks for adding this!

I tested with your script and also the following script to see if I can find some outliers in the TREC_6 dataset:

import json

from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models.one_class_classification_model import OneClassClassifier
from flair.trainers import ModelTrainer

embeddings = TransformerDocumentEmbeddings(
    model="distilbert-base-uncased",
)

# Train on TREC_6
trec = TREC_6()
label_type = "question_class"

# Identify outliers of the class ENTY
corpus = Corpus(
    train=[x for x in trec.train if x.get_label(label_type).value == "ENTY"],
    test=[x for x in trec.test if x.get_label(label_type).value == "ENTY"],
)

print(corpus)

label_dictionary = corpus.make_label_dictionary(label_type)
model = OneClassClassifier(embeddings, label_dictionary, label_type=label_type)

trainer = ModelTrainer(model, corpus)
trainer.fine_tune("resources/taggers/outlier", mini_batch_size=8)

threshold = model.calculate_threshold(corpus.dev, quantile=0.95)
model.threshold = threshold

result = model.evaluate(corpus.test, gold_label_type=label_type, out_path="predictions.txt")
print(json.dumps(result.classification_report, indent=2))

Some thoughts:

  • Having to set a threshold afterwards breaks with the default way of training/testing models in Flair. The trainer by default produces a dev.tsv and a test.tsv and prints evaluation metrics during/after training. However, without the threshold set during training, these outputs make no sense, which might confuse users.
  • Have you considered modeling this as a regression task instead of a classification task? Then users would not need to set a threshold. Also, this would allow users to do things like print out the 10 biggest outliers, which could be more useful than experimenting with different thresholds (see the sketch after this list).
  • The name of the class (OneClassClassifier) currently does not explain what it does. How about something like OutlierDetection?
  • The model currently is limited to datasets consisting of a single class. This means users will always need to first create a dataset like in your example snippet. Is it possible to do outlier detection for multiple classes at once? Or is this problematic since each class would need a separate encoder/decoder network?
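To illustrate the second point, the kind of usage a regression-style score would enable could look roughly like this (purely hypothetical; an anomaly_score method is not part of this PR):

# Hypothetical regression-style usage; `anomaly_score` is a made-up method.
scored = [(model.anomaly_score(sentence), sentence) for sentence in corpus.test]
for score, sentence in sorted(scored, key=lambda pair: pair[0], reverse=True)[:10]:
    print(f"{score:.4f}\t{sentence.text}")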

@jeffpicard (Contributor, Author)

Hi @alanakbik. I'm extremely sorry I took so long to reply, and many thanks for your thoughts.

To your points,

  1. Thanks for calling out the eval metrics confusion. I think this could be fixed by adding a plugin that calculates the threshold inside after_training_epoch (a rough sketch follows after this list). It would be similar to what's being done in this PR. Requiring a user to remember to use a plugin isn't ideal. Outside the scope here, but the plugin awkwardness might go away if flair gave models the ability to handle trainer events the way a plugin does, but without a plugin.
  2. Hmm, those are great properties that emerge if this is modeled as a regression. My worry is that for many anomaly detection datasets it would be hard to come up with a continuous label rather than a discrete one to serve as the target. Printing the top 10 can still be achieved in a classifier with return_probabilities_for_all_classes=True and sorting by the probability, perhaps. I haven't fully connected in my head how a regression would work end-to-end, but if you have and can help, that sounds great.
  3. What do you think about AnomalyDetection rather than OutlierDetection? It probably doesn't matter, but these terms sometimes carry more specific meanings: Outlier Detection means the training set contains both inliers and outliers, Novelty Detection means the training set contains only inliers, and Anomaly Detection covers either (e.g. sklearn's docs). The algorithm here really only applies to Novelty Detection, but maybe the future of this class involves more parameters specifying which algorithm to use.
  4. Thanks, I also thought throwing an exception for multi-class corpora is unexpected and not ideal generally. I think multi-class could be handled by adding another dimension to the tensors containing the extra encoder/decoder networks. I'll give that a try. It might also be possible to do "<unk>" vs any-class-in-training with a single encoder/decoder network.
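For point 1, here is roughly the plugin I have in mind, assuming flair's TrainerPlugin base class, its hook decorator, and the after_training_epoch event (a sketch only, untested):

from flair.trainers.plugins import TrainerPlugin

class ThresholdCalculationPlugin(TrainerPlugin):
    """Sketch: recompute the model's threshold after each training epoch."""

    def __init__(self, quantile: float = 0.95):
        super().__init__()
        self.quantile = quantile

    @TrainerPlugin.hook
    def after_training_epoch(self, **kwargs):
        # Uses the OneClassClassifier API from this PR (calculate_threshold / threshold).
        dev_split = self.trainer.corpus.dev
        if dev_split is not None:
            self.model.threshold = self.model.calculate_threshold(dev_split, quantile=self.quantile)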

Two more points I'm wondering what you (or others) think about:

  • flair/nn/decoder.py: Moving the embedding -> score logic into a decoder similar to PrototypicalDecoder rather than this class. This would allow the autoencoder technique to be reused in other classes (e.g. Regressors, TextPairClassifier, TextTripleClassifier), or swapped out in this class.
  • Anomaly Detection inside DefaultClassifier rather than this separate class. I think Anomaly Detection can be viewed as basically DefaultClassifier, except able to return "<unk>". DefaultClassifier could get a parameter, novelty: bool, and the implementation would change to something like:
    # in predict()
    if self.multi_label:
        sigmoided = ...
    elif self.novelty:
        scores = ...  # add an "<unk>" class for anything above the threshold
    else:
        softmax = ...

Altogether this could look like

anomaly_detector = TextClassifier(
    ...,  # embeddings, label_dictionary, label_type as usual
    novelty=True,
)
trainer.fine_tune(
    "./tmp",
    plugins=[ThresholdCalculationPlugin()],
)

I'm sorry this got so long! Focusing in on some yes/no questions that I think can be decided independently:

  • ThresholdPlugin
  • flair/nn/decoder.py
  • novelty=True option for DefaultClassifier
    • This is the only thing that modifies existing code

@alanakbik (Collaborator)

Hello @jeffpicard, also from my side, sorry it took so long to reply! Thanks for the many ideas and input!

  • ThresholdPlugin -> yes, I think that's a good idea, especially in combination with your "out-of-scope" idea. I think it would actually not be so difficult to give models the ability to contain default plugins: the ModelTrainer would simply need to check whether the model contains any plugins and add those to the training by default (roughly as sketched after this list). So, having a ThresholdPlugin now would also prepare (and motivate) a future step of having plugins at the models themselves.

  • decoder and support in DefaultClassifier -> yes, I think it would be great to have this ability for other types of tasks as well. For instance, @elenamer just added the NER_NOISEBENCH dataset which we are using for noise-robust learning research on the NER level (see our paper), so it would be interesting to see if this approach could be used for word-level predictions. I guess the parameter novelty=True would be hard to parse for users, so maybe something more descriptive such as use_as_anomaly_detector or so.
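To illustrate the default-plugin idea from the first point, the merging could look something like this (hypothetical; a default_plugins attribute on models is a made-up convention, nothing in flair today):

from flair.trainers.plugins import TrainerPlugin

# Hypothetical: a model advertises the plugins it wants attached by default,
# and ModelTrainer merges them with whatever the user passes in.
def collect_plugins(model, user_plugins: list[TrainerPlugin]) -> list[TrainerPlugin]:
    default_plugins = list(getattr(model, "default_plugins", []))
    return default_plugins + list(user_plugins)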
